{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp distributed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.basics import *\n",
    "from fastai.callback.progress import ProgressCallback\n",
    "from torch.nn.parallel import DistributedDataParallel, DataParallel\n",
    "from fastai.data.load import _FakeLoader,_loaders\n",
    "from fastai.optimizer import OptimWrapper\n",
    "try: from accelerate import Accelerator\n",
    "except ModuleNotFoundError: pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed training\n",
    "\n",
    "> Callbacks and helper functions to train in parallel or use distributed training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When using multiple GPUs, you will most probably want to fit using distributed training. \n",
    "\n",
    "Example use can be found:\n",
    "\n",
    "- In the form of a script with [examples/distrib.py](https://github.com/fastai/fastai/blob/master/nbs/examples/distrib.py)\n",
    "- Across all the App Examples with the [Notebook Launcher](https://docs.fast.ai/distributed_app_examples.html)\n",
    "- At the bottom of this notebook for more examples with `notebook_launcher`.\n",
    "\n",
    "To use distributed training, there are only three required steps:\n",
    "\n",
    "1. Add `with learn.distrib_ctx():` before your `learn.fit` call\n",
    "2. Either config Accelerate yourself by running `accelerate config` from the command line, or run:\n",
    "```python\n",
    "from accelerate.utils import write_basic_config\n",
    "write_basic_config()\n",
    "```\n",
    "3. Run your training script with `accelerate launch scriptname.py ...args...`\n",
    "\n",
    "\n",
    "If you're using `untar_data`, or may be downloading or uncompressing data or models as part of your script, you should wrap that code with `rank0_first`, which forces that step to occur first just once on the master process, prior to the remaining processes running it in parallel. E.g. instead of:\n",
    "\n",
    "```python\n",
    "path = untar_data(URLs.IMAGEWOOF_320)\n",
    "```\n",
    "\n",
    "...you instead use:\n",
    "\n",
    "```python\n",
    "path = rank0_first(untar_data, URLs.IMAGEWOOF_320)\n",
    "```\n",
    "\n",
    "See below for details on the full API and underlying helper functions, if needed -- however, note that you will not need anything except the above unless you need to change how the distributed training is implemented."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parallel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def reset(self: DataParallel):\n",
    "    \"Patch required `reset` call into `DataParallel`\"\n",
    "    if hasattr(self.module, 'reset'): self.module.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class ParallelTrainer(Callback):\n",
    "    \"Wrap a model `DataParallel` automatically\"\n",
    "    run_after,run_before = TrainEvalCallback,Recorder\n",
    "    def __init__(self, device_ids): self.device_ids = device_ids\n",
    "    def before_fit(self): self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)\n",
    "    def after_fit(self): self.learn.model = self.learn.model.module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def to_parallel(self: Learner, device_ids=None):\n",
    "    \"Add `ParallelTrainer` callback to a `Learner`\"\n",
    "    self.add_cb(ParallelTrainer(device_ids))\n",
    "    return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def detach_parallel(self: Learner):\n",
    "    \"Remove `ParallelTrainer` callback from a Learner\"\n",
    "    self.remove_cb(ParallelTrainer)\n",
    "    return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "@contextmanager\n",
    "def parallel_ctx(self: Learner, device_ids=None):\n",
    "    \"A context manager to adapt a learner to train in data parallel mode.\"\n",
    "    try:\n",
    "        self.to_parallel(device_ids)\n",
    "        yield self\n",
    "    finally: self.detach_parallel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distributed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def reset(self: DistributedDataParallel):\n",
    "    \"Patch required `reset` call into `DistributedDataParallel`\"\n",
    "    if hasattr(self.module, 'reset'): self.module.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def setup_distrib(gpu=None):\n",
    "    \"Setup this process to participate in distributed training\"\n",
    "    if gpu is None: return gpu\n",
    "    gpu = int(gpu)\n",
    "    torch.cuda.set_device(int(gpu))\n",
    "    if num_distrib() > 0: torch.distributed.init_process_group(backend='nccl', init_method='env://')\n",
    "    return gpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def teardown_distrib():\n",
    "    \"Free distributed training resources\"\n",
    "    if torch.distributed.is_initialized(): torch.distributed.destroy_process_group()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _round_to_multiple(number,multiple): return int(math.ceil(number/multiple)*multiple)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class DistributedDL(TfmdDL):\n",
    "    \"A `TfmdDL` which splits a batch into equal size pieces for each worker\"\n",
    "    def __init__(self,dl,rank=None,world_size=None,device=None):\n",
    "        if rank is None: rank=rank_distrib()\n",
    "        if world_size is None: world_size=num_distrib()\n",
    "        store_attr()\n",
    "        if type(dl) == torch.utils.data.DataLoader:\n",
    "            shuffle = True if eq(type(dl.sampler), torch.utils.data.RandomSampler) else False\n",
    "            self.dl = DataLoader(dataset=dl.dataset, bs=dl.batch_size, num_workers=dl.num_workers, \\\n",
    "                pin_memory=dl.pin_memory, timeout=dl.timeout, shuffle=shuffle, drop_last=dl.drop_last, persistent_workers=dl.persistent_workers)\n",
    "        self.bs,self.drop_last,self.dataset,fake,self.num_workers,self.offs,self.pin_memory = \\\n",
    "            attrgetter('bs','drop_last','dataset','fake_l','num_workers','offs','pin_memory')(self.dl)\n",
    "        if device is None: self.device = self.dl.device\n",
    "        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, \n",
    "                                  persistent_workers=fake.persistent_workers, \n",
    "                                  pin_memory_device=fake.pin_memory_device)\n",
    "        \n",
    "    def _broadcast(self,t,rank):\n",
    "        \"Broadcasts t from rank `rank` to all other ranks. Returns t so t is same for all ranks after call.\"\n",
    "        t = LongTensor(t).cuda() # nccl only works with cuda tensors\n",
    "        torch.distributed.broadcast(t,rank)\n",
    "        return t.cpu().tolist()\n",
    "\n",
    "    def _to_detach(self,b,cpu=True,gather=True): return to_detach(b,cpu,gather) # member func so we can override for test\n",
    "    def __len__(self): return _round_to_multiple(len(self.dl),self.world_size)//self.world_size\n",
    "    def get_idxs(self):\n",
    "        idxs = list(self.dl.get_idxs()) # compute get_idxs in all ranks (we'll only use rank 0 but size must be consistent)\n",
    "        idxs = self._broadcast(idxs,0)  # broadcast and receive it from rank 0 to all\n",
    "        self.n = len(idxs)              # we assumed n was dl.n but we really care about number of idxs\n",
    "        # add extra samples to make it evenly divisible\n",
    "        self.n_padded = _round_to_multiple(self.n,self.world_size)\n",
    "        idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] # idx needs to be repeated when n_padded>>n\n",
    "        # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors\n",
    "        return idxs[self.rank*self.n_padded//self.world_size:(self.rank+1)*self.n_padded//self.world_size]\n",
    "\n",
    "    def before_iter(self):\n",
    "        self.i = 0\n",
    "        self.dl.before_iter()\n",
    "\n",
    "    def randomize(self): self.dl.randomize()\n",
    "    def after_batch(self,b):\n",
    "        self.i += find_bs(b)\n",
    "        return self.dl.after_batch(b)\n",
    "\n",
    "    def after_iter(self):  self.dl.after_iter()\n",
    "    def create_batches(self,samps): return self.dl.create_batches(samps)\n",
    "    def to_detach(self,b, cpu=True, gather=True):\n",
    "        b = self._to_detach(b, cpu, gather)\n",
    "        def _inner(b):\n",
    "            if b.ndim>0:\n",
    "                # for each rank, compute overflow of read idxs vs self.n and accumulate them to unpad totals after gathering\n",
    "                n = sum([min(0,max(-len(b)//self.world_size,\n",
    "                                   self.n-(self.i+r*self.n_padded//self.world_size))) for r in range(self.world_size)])\n",
    "                b = b[:n or None]\n",
    "            return b\n",
    "        return apply(_inner,b) if gather and all(hasattr(self,o) for o in ('i','n','n_padded')) else b"
   ]
  },
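  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The padding and slicing in `DistributedDL.get_idxs` can be followed without any GPUs. Below is a minimal pure-Python sketch, not the library code: `shard_indices` is a hypothetical helper that mirrors the arithmetic above, and it leaves out the broadcast from rank 0.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def round_to_multiple(number, multiple):\n",
    "    # Smallest multiple of `multiple` that is >= `number`\n",
    "    return int(math.ceil(number / multiple) * multiple)\n",
    "\n",
    "def shard_indices(idxs, rank, world_size):\n",
    "    # Hypothetical helper: pad `idxs` by wrapping around, then return the slice for `rank`\n",
    "    n = len(idxs)\n",
    "    n_padded = round_to_multiple(n, world_size)\n",
    "    # Repeat the list as often as needed so that n_padded - n extra items are available\n",
    "    padded = idxs + (idxs * (n_padded // n))[:n_padded - n]\n",
    "    per_rank = n_padded // world_size\n",
    "    return padded[rank * per_rank:(rank + 1) * per_rank]\n",
    "\n",
    "# 50 items over 4 ranks: the list is padded to 52, so every rank gets 13 indices\n",
    "shards = [shard_indices(list(range(50)), r, world_size=4) for r in range(4)]\n",
    "```\n",
    "\n",
    "Every rank ends up with the same number of indices, and the last rank wraps around to indices 0 and 1 to fill its share."
   ]
  },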
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "_tmp_file = tempfile.NamedTemporaryFile().name\n",
    "# patch _broadcast with a mocked version so we can test DistributedDL w/o a proper DDP setup\n",
    "@patch\n",
    "def _broadcast(self:DistributedDL,t,rank):\n",
    "    t = LongTensor(t)\n",
    "    if rank == self.rank: torch.save(t,_tmp_file)\n",
    "    else:                 t.data = torch.load(_tmp_file)\n",
    "    return t.tolist()\n",
    "# patch _to_detach with a mocked version that will return right gathered size but -100 for other rank tensors\n",
    "@patch\n",
    "def _to_detach(self:DistributedDL,b,cpu=True,gather=True):\n",
    "    b = to_detach(b,cpu,gather)\n",
    "    if not gather: return b\n",
    "    def _inner(b, cpu, gather):\n",
    "        if b.ndim == 0: b=b[None]\n",
    "        b = torch.cat([b if i==self.rank else torch.full_like(b,-100) for i in range(self.world_size)])\n",
    "        return b if b.ndim > 0 else b.mean()\n",
    "    return apply(_inner,b,cpu,gather)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = TfmdDL(list(range(50)), bs=12, num_workers=2)\n",
    "for i in range(4):\n",
    "    dl1 = DistributedDL(dl, i, 4)\n",
    "    test_eq(list(dl1), (torch.arange(i*13, i*13+12)%50,torch.tensor([i*13+12])%50))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "dl = torch.utils.data.DataLoader(list(range(50)), batch_size=12, num_workers=2)\n",
    "for i in range(4):\n",
    "    dl1 = DistributedDL(dl, i, 4)\n",
    "    test_eq(list(dl1), (torch.arange(i*13, i*13+12)%50,torch.tensor([i*13+12])%50))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "dl = TfmdDL(list(zip(range(50),range(100,150))), bs=12, num_workers=4)\n",
    "for i in range(4):\n",
    "    dl1 = DistributedDL(dl, i, 4)\n",
    "    test_eq(list(dl1), [(torch.arange(i*13, i*13+12)%50,100+torch.arange(i*13, i*13+12)%50),\n",
    "                        ((torch.tensor([i*13+12])%50),100+torch.tensor([i*13+12])%50)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "dl = TfmdDL(list(range(50)), bs=12, num_workers=2,drop_last=True)\n",
    "for i in range(4):\n",
    "    dl1 = DistributedDL(dl, i, 4)\n",
    "    test_eq(list(dl1), [torch.arange(i*13, i*13+12)%50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "dl = TfmdDL(list(zip(range(12),range(100,112))), bs=12, num_workers=4)\n",
    "res,dls = [],[]\n",
    "for i in range(5): dls.append(DistributedDL(dl, i, 5))\n",
    "for b in zip(*dls):\n",
    "    for r in range(5):\n",
    "        d=L(dls[r].to_detach(b[r]))\n",
    "        test_eq(d.map(lambda x:(x!=-100).sum().item()),(3,3) if r!=4 else (0,0))"
   ]
  },
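  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When results are gathered across ranks, the samples that were repeated to pad the index list have to be dropped again. The sketch below restates the trimming arithmetic of `DistributedDL.to_detach` on plain lists. `unpad_gathered` and its argument names are hypothetical, and it assumes the gathered batch is the concatenation of equally sized per-rank batches.\n",
    "\n",
    "```python\n",
    "def unpad_gathered(gathered, n, n_padded, world_size, items_read):\n",
    "    # `items_read` is how many items each rank has consumed so far, including this batch\n",
    "    shard = n_padded // world_size       # number of indices assigned to each rank\n",
    "    per_rank = len(gathered) // world_size\n",
    "    overflow = 0\n",
    "    for r in range(world_size):\n",
    "        # Negative when rank r has read past the n real items\n",
    "        remaining = n - (items_read + r * shard)\n",
    "        # A rank cannot overflow by more than its own batch\n",
    "        overflow += min(0, max(-per_rank, remaining))\n",
    "    # `overflow` is zero or negative; trim that many items from the end\n",
    "    return gathered[:overflow or None]\n",
    "\n",
    "# 12 real items over 5 ranks: padded to 15, so the last rank holds 3 repeated items\n",
    "gathered = list(range(12)) + [0, 1, 2]\n",
    "trimmed = unpad_gathered(gathered, n=12, n_padded=15, world_size=5, items_read=3)\n",
    "```\n",
    "\n",
    "Here only the last rank read past the real data, so its three repeated items are cut from the end of the gathered batch."
   ]
  },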
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "dl = TfmdDL(list(range(10)), bs=4, num_workers=2, shuffle=True)\n",
    "res = []\n",
    "for i in range(3):\n",
    "    dl1 = DistributedDL(dl, i, 3)\n",
    "    b  = list(dl1)[0]\n",
    "    bd = dl1.to_detach(b)\n",
    "    test_eq(b[:None if i<2 else 2],bd[4*i:4*(i+1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from fastai.callback.data import WeightedDL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "dl = WeightedDL(list(range(50)), bs=16, num_workers=2, shuffle=True,wgts=list(np.arange(50)>=25))\n",
    "res = []\n",
    "for i in range(4):\n",
    "    dl1 = DistributedDL(dl, i, 4)\n",
    "    res += list(dl1)[0].tolist()\n",
    "test(res,[25]*len(res),operator.ge)        # all res >=25\n",
    "test(res,[25]*len(res),lambda a,b: not (a<b)) # all res NOT < 25"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DistributedTrainer -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "_hidden_params = [\"mixed_precision\", \"fp16\", \"log_with\", \"logging_dir\", \"step_scheduler_with_optimizer\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class DistributedTrainer(Callback):\n",
    "    \"Wrap `model` in `DistributedDataParallel` and `dls` in `DistributedDL`\"\n",
    "    order = 11\n",
    "    @delegates(Accelerator, but=_hidden_params)\n",
    "    def __init__(self,\n",
    "        sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm`\n",
    "        **kwargs\n",
    "    ):\n",
    "        store_attr()\n",
    "        self.accelerator = Accelerator(**kwargs)\n",
    "    def before_fit(self):\n",
    "        self.learn.model = self.accelerator.prepare(\n",
    "            nn.SyncBatchNorm.convert_sync_batchnorm(self.model) if self.sync_bn else self.model\n",
    "        )\n",
    "        self.old_dls = list(self.dls)\n",
    "        self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls]\n",
    "        if rank_distrib(): self.learn.logger=noop\n",
    "\n",
    "    def _wrap_dl(self, dl): return dl if isinstance(dl,DistributedDL) else DistributedDL(dl, device=self.learn.model.device)\n",
    "    def _backward(self): self.accelerator.backward(self.learn.loss_grad)\n",
    "    \n",
    "    def before_train(self):    self.learn.dl = self._wrap_dl(self.learn.dl)\n",
    "    def before_validate(self): self.learn.dl = self._wrap_dl(self.learn.dl)\n",
    "    def after_fit(self): self.learn.model,self.learn.dls.loaders = self.learn.model.module,self.old_dls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "@delegates(Accelerator, but=_hidden_params)\n",
    "def to_distributed(self: Learner,\n",
    "        sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm`\n",
    "        **kwargs\n",
    "    ):\n",
    "    \"Add `AcceleratedTrainer` to a learner, and configures an Accelerator\"\n",
    "    self.add_cb(DistributedTrainer(sync_bn, **kwargs))\n",
    "    if rank_distrib(): self.remove_cb(ProgressCallback)\n",
    "    return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def detach_distributed(self: Learner):\n",
    "    \"Remove `DistributedTrainer` from a learner\"\n",
    "    if num_distrib() <=1: return self\n",
    "    self.remove_cb(DistributedTrainer)\n",
    "    if rank_distrib() and not hasattr(self, 'progress'): self.add_cb(ProgressCallback())\n",
    "    return self"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `distrib_ctx` context manager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "@contextmanager\n",
    "@delegates(Accelerator, but=_hidden_params)\n",
    "def distrib_ctx(self: Learner,\n",
    "        sync_bn=True, # Whether to replace all batch norm with `nn.SyncBatchNorm`\n",
    "        in_notebook=False, # Whether we are launching from a notebook or not\n",
    "        **kwargs\n",
    "   ):\n",
    "    \"A context manager to adapt a learner to train in distributed data parallel mode.\"\n",
    "    try: import accelerate\n",
    "    except ImportError as e: \n",
    "        e.args = [\"Accelerate is required. Install with `pip install accelerate`\"]\n",
    "        raise\n",
    "    # Adapt self to DistributedDataParallel, yield, and cleanup afterwards.\n",
    "    cleanup_dpg = False\n",
    "    try:\n",
    "        if in_notebook:\n",
    "            cuda_id = rank_distrib()\n",
    "            if not torch.distributed.is_initialized():\n",
    "                setup_distrib(cuda_id)\n",
    "                cleanup_dpg = torch.distributed.is_initialized()\n",
    "            if not rank_distrib(): print(\"Training Learner...\")\n",
    "        if num_distrib(): self.to_distributed(sync_bn, **kwargs)\n",
    "        yield self\n",
    "    finally:\n",
    "        self.detach_distributed()\n",
    "        if cleanup_dpg: teardown_distrib()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`distrib_ctx` prepares a learner to train in distributed data parallel mode. It assumes the script/code will either be ran through the command line via `accelerate launch` or through the `notebook_launcher` function from Accelerate. It also assumes that `accelerate` has been configured through either running `write_basic_config()` or calling `accelerate config` through the CLI and answering the prompts.\n",
    "\n",
    "Typical usage:\n",
    "\n",
    "```\n",
    "with learn.distrib_ctx(): learn.fit(.....)\n",
    "```\n",
    "\n",
    "It attaches a `DistributedTrainer` callback and `DistributedDL` data loader to  the learner, then executes `learn.fit(.....)`.  Upon exiting the context, it removes the `DistributedTrainer` and `DistributedDL`, and destroys any locally created distributed process group.  The process is still attached to the GPU though."
   ]
  },
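  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The attach, run, detach sequence described above follows the usual `contextmanager` pattern with a `finally` clause. The following is a toy illustration in plain Python, with a hypothetical `ToyLearner` standing in for `Learner`. It is not how fastai manages callbacks, only a sketch of why the cleanup still runs if training raises.\n",
    "\n",
    "```python\n",
    "from contextlib import contextmanager\n",
    "\n",
    "class ToyLearner:\n",
    "    # Hypothetical stand-in for `Learner`: it only tracks callback names\n",
    "    def __init__(self): self.cbs = []\n",
    "\n",
    "    @contextmanager\n",
    "    def distrib_ctx(self):\n",
    "        self.cbs.append('DistributedTrainer')           # attach on entry\n",
    "        try: yield self\n",
    "        finally: self.cbs.remove('DistributedTrainer')  # always detach on exit\n",
    "\n",
    "learn = ToyLearner()\n",
    "with learn.distrib_ctx():\n",
    "    inside = list(learn.cbs)   # the callback is present while training would run\n",
    "\n",
    "try:\n",
    "    with learn.distrib_ctx(): raise RuntimeError('training failed')\n",
    "except RuntimeError:\n",
    "    pass\n",
    "after_error = list(learn.cbs)  # the callback was removed despite the exception\n",
    "```\n",
    "\n",
    "The same reasoning applies to the real context manager: because the detach step sits in `finally`, the learner is restored whether `learn.fit` finishes or fails."
   ]
  },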
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def rank0_first(func, *args, **kwargs):\n",
    "    \"Execute `func` in the Rank-0 process first, then in other ranks in parallel.\"\n",
    "    if args or kwargs: func = partial(func, *args, **kwargs)\n",
    "    dummy_l = Learner(DataLoaders(device='cpu'), nn.Linear(1,1), loss_func=lambda: 0)\n",
    "    with dummy_l.distrib_ctx():\n",
    "        if not rank_distrib(): res = func()\n",
    "        distrib_barrier()\n",
    "        if rank_distrib(): res = func()\n",
    "    return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`rank0_first` calls `f()` in rank-0 process first, then in parallel on the rest, in distributed training mode. In single process, non-distributed training mode, `f()` is called only once as expected.\n",
    "\n",
    "One application of `rank0_first()` is to make fresh downloads via `untar_data` safe in distributed training scripts launched by `python -m fastai.launch <script>`:\n",
    "\n",
    "<code>path = untar_data(URLs.IMDB)</code>\n",
    "\n",
    "becomes:\n",
    "\n",
    "<code>path = rank0_first(lambda: untar_data(URLs.IMDB))</code>\n",
    "\n",
    "Some learner factory methods may use `untar_data` to download pretrained models:\n",
    "\n",
    "<code>learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)</code>\n",
    "\n",
    "becomes:\n",
    "\n",
    "<code>learn = rank0_first(lambda: text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy))</code>\n",
    "\n",
    "Otherwise, multiple processes will download at the same time and corrupt the data."
   ]
  },
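  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ordering that `rank0_first` enforces can be imitated with threads standing in for processes. This is only a sketch under that assumption: `run_rank0_first`, the thread-based ranks, and `threading.Barrier` (used here in place of `distrib_barrier`) are illustrative choices, not what fastai does internally.\n",
    "\n",
    "```python\n",
    "import threading\n",
    "\n",
    "def run_rank0_first(func, world_size):\n",
    "    # Simulate `world_size` ranks with threads; rank 0 calls `func` before the others\n",
    "    barrier = threading.Barrier(world_size)\n",
    "    order, lock = [], threading.Lock()\n",
    "\n",
    "    def worker(rank):\n",
    "        if rank == 0:\n",
    "            func()\n",
    "            with lock: order.append(rank)\n",
    "        barrier.wait()  # no rank passes this point until rank 0 has finished\n",
    "        if rank != 0:\n",
    "            func()\n",
    "            with lock: order.append(rank)\n",
    "\n",
    "    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]\n",
    "    for t in threads: t.start()\n",
    "    for t in threads: t.join()\n",
    "    return order\n",
    "\n",
    "calls = []\n",
    "order = run_rank0_first(lambda: calls.append('download'), world_size=4)\n",
    "```\n",
    "\n",
    "Rank 0 always appears first in `order`, while the remaining ranks may finish in any order. With a real download, the later calls would find the data already in place instead of racing to write it."
   ]
  },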
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notebook Launcher\n",
    "\n",
    "Accelerate provides a [notebook_launcher](https://huggingface.co/docs/accelerate/launcher) functionality to let you keep using your Jupyter Notebook as you would, but train in a distributed setup!\n",
    "\n",
    "First, make sure accelerate is properly configured. You can either run `accelerate config` from the command line, or have an autofilled configuration setup by running in the first cell of your notebook:\n",
    "\n",
    "```python\n",
    "from accelerate.utils import write_basic_config\n",
    "write_basic_config()\n",
    "```\n",
    "After Accelerate is configured, to utilize the `notebook_launcher` functionality migrate your training into a function, and pass this to `notebook_launcher`, such as:\n",
    "\n",
    "```python\n",
    "---\n",
    "from fastai.vision.all import *\n",
    "from fastai.distributed import *\n",
    "\n",
    "def train():\n",
    "    set_seed(99, True)\n",
    "    path = untar_data(URLs.PETS)/'images'\n",
    "    dls = ImageDataLoaders.from_name_func(\n",
    "        path, get_image_files(path), valid_pct=0.2,\n",
    "        label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))\n",
    "    \n",
    "    learn = vision_learner(dls, resnet34, metrics=error_rate).to_fp16()\n",
    "    with learn.distrib_ctx(in_notebook=True):\n",
    "        learn.fine_tune(1)\n",
    "---\n",
    "from accelerate import notebook_launcher\n",
    "notebook_launcher(train, num_processes=2)\n",
    "---\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
